{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c4eab109-d2c0-4409-936e-94f07050816e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,DataCollatorForSeq2Seq\n",
    "from peft import get_peft_model, LoraConfig,TaskType\n",
    "from peft.utils import prepare_model_for_kbit_training\n",
    "import torch\n",
    "from datasets import Dataset\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "37299903-fb62-49d9-ab37-a1d17ff74918",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the raw instruction-tuning records from JSON and wrap them in a Hugging Face Dataset.\n",
    "df = pd.read_json('../datasets/cerbo.json')\n",
    "ds = Dataset.from_pandas(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "335f19ac-5320-4696-bbe2-5afd5210267c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['instruction', 'input', 'output'],\n",
       "    num_rows: 91\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the dataset summary (column names and row count).\n",
    "ds"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5b841c2-115d-4d8d-ae5a-b7a315119e2b",
   "metadata": {},
   "source": [
    "## Processing the training dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "894c4ef3-36bb-4cd0-b50b-3b119296379d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base checkpoint that is fine-tuned below; its tokenizer is loaded with use_fast=False.\n",
    "model_path = \"google/gemma-2-2b-it\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "# Pad with the EOS token id, and pad on the right so real tokens stay at the start of each row.\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "tokenizer.padding_side = 'right'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3055bab-07f7-470b-bbdc-fe777a4d4e19",
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_func(example, max_length=384):\n",
    "    \"\"\"Tokenize one record into the Gemma-2 chat format for supervised fine-tuning.\n",
    "\n",
    "    The prompt is the user turn (instruction immediately followed by input) plus the\n",
    "    opening of the model turn. The response is the output followed by <end_of_turn>\n",
    "    and a final EOS token. Prompt positions are labelled -100 so that the loss is\n",
    "    computed on the response tokens only.\n",
    "\n",
    "    Args:\n",
    "        example: Mapping with 'instruction', 'input' and 'output' text fields.\n",
    "        max_length: Maximum number of tokens kept per example; longer sequences are\n",
    "            truncated from the right.\n",
    "\n",
    "    Returns:\n",
    "        Dict with equal-length 'input_ids', 'attention_mask' and 'labels' lists.\n",
    "    \"\"\"\n",
    "    # A missing field (None) is treated as empty text instead of raising on str + None.\n",
    "    user_text = (example['instruction'] or \"\") + (example['input'] or \"\")\n",
    "    prompt = tokenizer(\n",
    "        f\"<bos><start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n\",\n",
    "        add_special_tokens=False,  # <bos> is already written into the template\n",
    "    )\n",
    "    response = tokenizer(f\"{example['output']}<end_of_turn>\\n\", add_special_tokens=False)\n",
    "\n",
    "    # Terminate with the real EOS id rather than whatever the pad token is configured to be.\n",
    "    eos = [tokenizer.eos_token_id]\n",
    "    input_ids = prompt[\"input_ids\"] + response[\"input_ids\"] + eos\n",
    "    attention_mask = prompt[\"attention_mask\"] + response[\"attention_mask\"] + [1]\n",
    "    labels = [-100] * len(prompt[\"input_ids\"]) + response[\"input_ids\"] + eos\n",
    "\n",
    "    # Slicing is a no-op for sequences that are already short enough.\n",
    "    return {\n",
    "        \"input_ids\": input_ids[:max_length],\n",
    "        \"attention_mask\": attention_mask[:max_length],\n",
    "        \"labels\": labels[:max_length],\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ac487644-7bbc-4160-96f8-186cbb138055",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39a366ffba544b4fb7a06f88c6195955",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/91 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['input_ids', 'attention_mask', 'labels'],\n",
       "    num_rows: 91\n",
       "})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Tokenize every record and drop the original text columns so only model inputs remain.\n",
    "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
    "tokenized_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "330628cc-6ef1-4d1c-a3ca-24a64e01866b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<bos><start_of_turn>user\n",
      "hi<end_of_turn>\n",
      "<start_of_turn>model\n",
      "Hello! How can I assist you today?<end_of_turn>\n",
      "<eos>\n"
     ]
    }
   ],
   "source": [
    "# Decode the first tokenized example to check the chat template and special tokens.\n",
    "print(tokenizer.decode(tokenized_id[0]['input_ids']))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76017971-acca-43be-a51d-a96b8d54ae0e",
   "metadata": {},
   "source": [
    "## Create model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cd537edd-d959-4964-9c84-d4b4bbf7a196",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask bitsandbytes to load the base model weights in 8-bit.\n",
    "quantization_config = BitsAndBytesConfig(load_in_8bit=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9237ee9e-b956-44f0-b28a-b62556403aad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dcd8d4f9980b429f9aec0d3aeecdecb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the 8-bit quantized base model. Gemma-2 should be trained with eager attention\n",
    "# (transformers warns against `sdpa` for this architecture). The architecture ships with\n",
    "# transformers itself, so trust_remote_code is left at its safe default.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    quantization_config=quantization_config,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "eeb18420-554b-483b-8244-1186b4e358e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Gemma2ForCausalLM(\n",
       "  (model): Gemma2Model(\n",
       "    (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-25): 26 x Gemma2DecoderLayer(\n",
       "        (self_attn): Gemma2SdpaAttention(\n",
       "          (q_proj): Linear8bitLt(in_features=2304, out_features=2048, bias=False)\n",
       "          (k_proj): Linear8bitLt(in_features=2304, out_features=1024, bias=False)\n",
       "          (v_proj): Linear8bitLt(in_features=2304, out_features=1024, bias=False)\n",
       "          (o_proj): Linear8bitLt(in_features=2048, out_features=2304, bias=False)\n",
       "          (rotary_emb): Gemma2RotaryEmbedding()\n",
       "        )\n",
       "        (mlp): Gemma2MLP(\n",
       "          (gate_proj): Linear8bitLt(in_features=2304, out_features=9216, bias=False)\n",
       "          (up_proj): Linear8bitLt(in_features=2304, out_features=9216, bias=False)\n",
       "          (down_proj): Linear8bitLt(in_features=9216, out_features=2304, bias=False)\n",
       "          (act_fn): PytorchGELUTanh()\n",
       "        )\n",
       "        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "        (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "      )\n",
       "    )\n",
       "    (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module tree of the loaded model.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "699e23ee-f508-4bf4-850f-5367abd61ee2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the input embeddings' outputs require gradients. NOTE(review): presumably needed so\n",
    "# gradient checkpointing (enabled in the training arguments) can backpropagate into the\n",
    "# adapters while the base weights are frozen - confirm against the PEFT documentation.\n",
    "model.enable_input_require_grads()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e49a651-9de4-4911-beef-1dee87607828",
   "metadata": {},
   "source": [
    "## LoRA configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3713dd42-7f7d-48cf-883a-8fb0a7f67a1a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LoraConfig(peft_type=<PeftType.LORA: 'LORA'>, auto_mapping=None, base_model_name_or_path=None, revision=None, task_type=<TaskType.CAUSAL_LM: 'CAUSAL_LM'>, inference_mode=False, r=8, target_modules={'k_proj', 'v_proj', 'q_proj', 'up_proj', 'gate_proj', 'o_proj', 'down_proj'}, lora_alpha=16, lora_dropout=0.055, fan_in_fan_out=False, bias='none', use_rslora=False, modules_to_save=None, init_lora_weights=True, layers_to_transform=None, layers_pattern=None, rank_pattern={}, alpha_pattern={}, megatron_config=None, megatron_core='megatron.core', loftq_config={}, use_dora=False, layer_replication=None)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LoRA adapter settings: rank-8 adapters on every attention and MLP projection layer.\n",
    "attention_projections = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"]\n",
    "mlp_projections = ['gate_proj', 'up_proj', 'down_proj']\n",
    "\n",
    "lora_settings = dict(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    target_modules=attention_projections + mlp_projections,\n",
    "    inference_mode=False,  # adapters are being trained, not just loaded\n",
    "    r=8,\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0.055,\n",
    "    bias=\"none\",\n",
    ")\n",
    "config = LoraConfig(**lora_settings)\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "950035e4-db83-4eda-82e4-2eff1b185cfa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModelForCausalLM(\n",
       "  (base_model): LoraModel(\n",
       "    (model): Gemma2ForCausalLM(\n",
       "      (model): Gemma2Model(\n",
       "        (embed_tokens): Embedding(256000, 2304, padding_idx=0)\n",
       "        (layers): ModuleList(\n",
       "          (0-25): 26 x Gemma2DecoderLayer(\n",
       "            (self_attn): Gemma2SdpaAttention(\n",
       "              (q_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=2304, out_features=2048, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=2304, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=2048, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (k_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=2304, out_features=1024, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=2304, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=1024, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (v_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=2304, out_features=1024, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=2304, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=1024, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (o_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=2048, out_features=2304, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=2048, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=2304, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (rotary_emb): Gemma2RotaryEmbedding()\n",
       "            )\n",
       "            (mlp): Gemma2MLP(\n",
       "              (gate_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=2304, out_features=9216, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=2304, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=9216, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (up_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=2304, out_features=9216, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=2304, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=9216, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (down_proj): lora.Linear8bitLt(\n",
       "                (base_layer): Linear8bitLt(in_features=9216, out_features=2304, bias=False)\n",
       "                (lora_dropout): ModuleDict(\n",
       "                  (default): Dropout(p=0.055, inplace=False)\n",
       "                )\n",
       "                (lora_A): ModuleDict(\n",
       "                  (default): Linear(in_features=9216, out_features=8, bias=False)\n",
       "                )\n",
       "                (lora_B): ModuleDict(\n",
       "                  (default): Linear(in_features=8, out_features=2304, bias=False)\n",
       "                )\n",
       "                (lora_embedding_A): ParameterDict()\n",
       "                (lora_embedding_B): ParameterDict()\n",
       "              )\n",
       "              (act_fn): PytorchGELUTanh()\n",
       "            )\n",
       "            (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "            (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "            (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "            (post_feedforward_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "          )\n",
       "        )\n",
       "        (norm): Gemma2RMSNorm((2304,), eps=1e-06)\n",
       "      )\n",
       "      (lm_head): Linear(in_features=2304, out_features=256000, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the base model with the LoRA adapters described by `config` and display the result.\n",
    "model = get_peft_model(model, config)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "aadefc23-66f2-4b32-af5d-4e22937ae052",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 10,383,360 || all params: 2,624,725,248 || trainable%: 0.3956\n"
     ]
    }
   ],
   "source": [
    "# Report how many parameters are trainable compared with the total parameter count.\n",
    "model.print_trainable_parameters()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3a0de08-9d0d-4ee2-8468-8eb719604334",
   "metadata": {},
   "source": [
    "## Configure training parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7f8fd582-4a27-4aca-a679-163f1efeb96e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import SFTConfig, SFTTrainer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "2c841522-7b9d-4a53-b4ba-af5e1d402d67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training hyperparameters. The dataset is already tokenized (input_ids / attention_mask /\n",
    "# labels), so no `dataset_text_field` is configured: there is no raw-text column to read.\n",
    "args = SFTConfig(\n",
    "    output_dir=\"./output/gemma2\",\n",
    "    per_device_train_batch_size=4,\n",
    "    gradient_accumulation_steps=4,  # effective batch size of 16 per device\n",
    "    logging_steps=20,\n",
    "    log_level=\"info\",\n",
    "    num_train_epochs=50,\n",
    "    save_steps=100,\n",
    "    learning_rate=1e-4,\n",
    "    save_total_limit=2,\n",
    "    gradient_checkpointing=True,\n",
    "    # Select the non-reentrant checkpoint implementation explicitly, as PyTorch recommends.\n",
    "    gradient_checkpointing_kwargs={\"use_reentrant\": False},\n",
    "    # Matches the 384-token truncation applied in process_func.\n",
    "    max_seq_length=384,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "be791275-244e-4cfe-a380-920abb06d513",
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import SFTTrainer\n",
    "\n",
    "# Batches are padded dynamically to the longest sequence they contain.\n",
    "collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)\n",
    "\n",
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_id,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4afe388b-db7d-4073-8762-a3dc8c4eaaf2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num examples = 91\n",
      "  Num Epochs = 50\n",
      "  Instantaneous batch size per device = 4\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 4\n",
      "  Total optimization steps = 250\n",
      "  Number of trainable parameters = 10,383,360\n",
      "It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.\n",
      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='250' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [250/250 06:37, Epoch 43/50]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>2.030500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.695000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.273100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.094600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.043200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.019800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>0.012200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>0.007600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>0.004100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.002800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>0.002800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>0.002200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./output/gemma2/checkpoint-100\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/zxd/workspace/models/gemma2 - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "tokenizer config file saved in ./output/gemma2/checkpoint-100/tokenizer_config.json\n",
      "Special tokens file saved in ./output/gemma2/checkpoint-100/special_tokens_map.json\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
      "Saving model checkpoint to ./output/gemma2/checkpoint-200\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/zxd/workspace/models/gemma2 - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "tokenizer config file saved in ./output/gemma2/checkpoint-200/tokenizer_config.json\n",
      "Special tokens file saved in ./output/gemma2/checkpoint-200/special_tokens_map.json\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/torch/utils/checkpoint.py:460: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.\n",
      "  warnings.warn(\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:316: UserWarning: MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization\n",
      "  warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
      "Saving model checkpoint to ./output/gemma2/checkpoint-250\n",
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/zxd/workspace/models/gemma2 - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "tokenizer config file saved in ./output/gemma2/checkpoint-250/tokenizer_config.json\n",
      "Special tokens file saved in ./output/gemma2/checkpoint-250/special_tokens_map.json\n",
      "Deleting older checkpoint [output/gemma2/checkpoint-100] due to args.save_total_limit\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=250, training_loss=0.25513947682082655, metrics={'train_runtime': 400.4094, 'train_samples_per_second': 11.363, 'train_steps_per_second': 0.624, 'total_flos': 2603082304664064.0, 'train_loss': 0.25513947682082655, 'epoch': 43.47826086956522})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the fine-tuning loop; checkpoints are written under the configured output_dir.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a44e45b-ce94-4e99-8ce8-54e32df898a3",
   "metadata": {},
   "source": [
    "## Save LoRA and tokenizer results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1254dd7c-4602-4f2c-a702-042043a47ce1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zxd/miniconda3/envs/llm/lib/python3.10/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in /home/zxd/workspace/models/gemma2 - will assume that the vocabulary was not modified.\n",
      "  warnings.warn(\n",
      "tokenizer config file saved in ./output/gemma2/tokenizer_config.json\n",
      "Special tokens file saved in ./output/gemma2/special_tokens_map.json\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('./output/gemma2/tokenizer_config.json',\n",
       " './output/gemma2/special_tokens_map.json',\n",
       " './output/gemma2/tokenizer.model',\n",
       " './output/gemma2/added_tokens.json')"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Write the trained PEFT model and the tokenizer files into the same directory.\n",
    "lora_path='./output/gemma2'\n",
    "trainer.model.save_pretrained(lora_path)\n",
    "tokenizer.save_pretrained(lora_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cea5a92-f9fc-4229-ac2e-1d92fb26e33d",
   "metadata": {},
   "source": [
    "## Load the LoRA weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ea0d76ee-c0eb-438d-ae0f-61b02fddd3e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6d094b5ee974b1597b27ccfa8e6071e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch \n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline \n",
    "import torch\n",
    "from peft import (\n",
    "    PeftModel,\n",
    "    LoraConfig,\n",
    "    get_peft_model,\n",
    "    get_peft_model_state_dict,\n",
    "    prepare_model_for_kbit_training,\n",
    "    set_peft_model_state_dict,\n",
    ")\n",
    "model_path = \"google/gemma-2-2b-it\"\n",
    "# Reload the base model without quantization for inference. Gemma-2 is implemented in\n",
    "# transformers itself, so trust_remote_code is left at its safe default.\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    model_path,\n",
    "    device_map=\"cuda\",\n",
    "    torch_dtype=\"auto\",\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7dd57773-ac6f-443c-b39b-bccfaad6c867",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are you? Roderick\n",
      "who are\n"
     ]
    }
   ],
   "source": [
    "prompt = \"who are you?\"\n",
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": prompt},\n",
    "]\n",
    "\n",
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Greedy decoding: sampling is disabled, so no temperature is passed. A temperature is\n",
    "# only used by sampling-based generation and 0.0 is not a valid sampling temperature.\n",
    "generation_args = {\n",
    "    \"max_new_tokens\": 50,\n",
    "    \"return_full_text\": False,\n",
    "    \"do_sample\": False,\n",
    "}\n",
    "\n",
    "# Baseline reply from the base model, before the LoRA adapter is attached.\n",
    "output = pipe(messages, **generation_args)\n",
    "print(output[0]['generated_text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fbaa84c7-1cd6-4c96-be54-2b4e7b1489d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the fine-tuned LoRA adapter saved under ./output/gemma2 to the base model.\n",
    "lora_config_path = './output/gemma2'\n",
    "config = LoraConfig.from_pretrained(lora_config_path)\n",
    "\n",
    "model = PeftModel.from_pretrained(model, model_id=lora_config_path, config=config)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "23987131-43bd-418e-8542-de29bd5de528",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'NemotronForCausalLM', 'OlmoForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonForCausalLM', 'PhiForCausalLM', 'Phi3ForCausalLM', 'PLBartForCausalLM', 'ProphetNetForCausalLM', 'QDQBertLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'RecurrentGemmaForCausalLM', 'ReformerModelWithLMHead', 'RemBertForCausalLM', 'RobertaForCausalLM', 'RobertaPreLayerNormForCausalLM', 'RoCBertForCausalLM', 'RoFormerForCausalLM', 'RwkvForCausalLM', 'Speech2Text2ForCausalLM', 'StableLmForCausalLM', 'Starcoder2ForCausalLM', 'TransfoXLLMHeadModel', 'TrOCRForCausalLM', 'WhisperForCausalLM', 'XGLMForCausalLM', 'XLMWithLMHeadModel', 'XLMProphetNetForCausalLM', 'XLMRobertaForCausalLM', 'XLMRobertaXLForCausalLM', 'XLNetLMHeadModel', 'XmodForCausalLM'].\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I am Fairy, an AI assistant developed by Cerbo AI. How can I assist you today?\n",
      "\n"
     ]
    }
   ],
   "source": [
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": \"who are you?\"},\n",
    "]\n",
    "\n",
    "# The pipeline logs that PeftModelForCausalLM is not in its supported-model list;\n",
    "# generation still runs through the wrapped model.\n",
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Greedy decoding: sampling is disabled, so no temperature is passed. A temperature is\n",
    "# only used by sampling-based generation and 0.0 is not a valid sampling temperature.\n",
    "generation_args = {\n",
    "    \"max_new_tokens\": 500,\n",
    "    \"return_full_text\": False,\n",
    "    \"do_sample\": False,\n",
    "}\n",
    "\n",
    "# Reply from the model with the LoRA adapter attached.\n",
    "output = pipe(messages, **generation_args)\n",
    "print(output[0]['generated_text'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c05e2c9e-7e99-47ab-8dbd-c12da0fdbf12",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cd2ab0f-3a66-4352-969a-e5775c646aa3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
